{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3c5d72f4",
   "metadata": {},
   "source": [
    "# Finetune with LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6fd9cda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install transformers datasets lightning watermark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "033b75c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       : 2.1.1+cu121\n",
      "transformers: 4.36.2\n",
      "datasets    : 2.16.1\n",
      "lightning   : 2.1.2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -p torch,transformers,datasets,lightning"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "59e5aefc-40b8-4361-bc2a-6e12d25235d6",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 1 Loading the dataset into DataFrames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e39e2228-5f0b-4fb9-b762-df26c2052b45",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datasets import load_dataset\n",
    "\n",
    "import lightning as L\n",
    "from lightning.pytorch.loggers import CSVLogger\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset\n",
    "from local_dataset_utilities import IMDBDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2ed9af91",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not torch.cuda.is_available():\n",
    "    print(\"Please switch to a GPU machine before running this notebook.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb31ac90-9e3a-41d0-baf1-8e613043924b",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = (\"test.csv\", \"train.csv\", \"val.csv\")\n",
    "\n",
    "# NOTE: despite its name, `download` is True when every split file already\n",
    "# exists under data/ and False when at least one of them is missing.\n",
    "download = all(os.path.exists(os.path.join(\"data\", name)) for name in files)\n",
    "\n",
    "# Fetch, load and partition the dataset only when a split file is missing.\n",
    "if not download:\n",
    "    download_dataset()\n",
    "    df = load_dataset_into_to_dataframe()\n",
    "    partition_dataset(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "221f30a1-b433-4304-a18d-8d03abd42b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train = pd.read_csv(os.path.join(\"data\", \"train.csv\"))\n",
    "df_val = pd.read_csv(os.path.join(\"data\", \"val.csv\"))\n",
    "df_test = pd.read_csv(os.path.join(\"data\", \"test.csv\"))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "846d83b1",
   "metadata": {},
   "source": [
    "# 2 Tokenization and Numericalization"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2bd5f770",
   "metadata": {},
   "source": [
    "**Load the dataset via `load_dataset`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1aa66c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 35000\n",
      "    })\n",
      "    validation: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 5000\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 10000\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "imdb_dataset = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files={\n",
    "        \"train\": os.path.join(\"data\", \"train.csv\"),\n",
    "        \"validation\": os.path.join(\"data\", \"val.csv\"),\n",
    "        \"test\": os.path.join(\"data\", \"test.csv\"),\n",
    "    },\n",
    ")\n",
    "\n",
    "print(imdb_dataset)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "029adc8f-cdfe-4386-9552-a1120f49adee",
   "metadata": {},
   "source": [
    "**Tokenize the dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5ea762ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenizer input max length: 512\n",
      "Tokenizer vocabulary size: 30522\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "print(\"Tokenizer input max length:\", tokenizer.model_max_length)\n",
    "print(\"Tokenizer vocabulary size:\", tokenizer.vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8432c15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_text(batch):\n",
    "    \"\"\"Run the notebook-level `tokenizer` over the \"text\" entries of `batch`.\n",
    "\n",
    "    Truncation and padding are both switched on in the tokenizer call.\n",
    "    \"\"\"\n",
    "    texts = batch[\"text\"]\n",
    "    return tokenizer(texts, truncation=True, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0bb392cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6d4103c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "del imdb_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "89ef894c-978f-47f2-9d61-cb6a9f38e745",
   "metadata": {},
   "outputs": [],
   "source": [
    "imdb_tokenized.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0ea67091-aeb7-46c1-871f-638ce58d8a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4c4f8cd8-e641-45fb-9893-70677631917a",
   "metadata": {},
   "source": [
    "# 3 Set Up DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0807b068-7d8f-4055-a26a-177e07dea4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "\n",
    "class IMDBDataset(Dataset):\n",
    "    def __init__(self, dataset_dict, partition_key=\"train\"):\n",
    "        self.partition = dataset_dict[partition_key]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.partition[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.partition.num_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "90cb08f3-ef77-4351-8b19-42d99dd24f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared DataLoader settings, named once so the three loaders stay in sync.\n",
    "BATCH_SIZE = 12\n",
    "NUM_WORKERS = 4\n",
    "\n",
    "train_dataset = IMDBDataset(imdb_tokenized, partition_key=\"train\")\n",
    "val_dataset = IMDBDataset(imdb_tokenized, partition_key=\"validation\")\n",
    "test_dataset = IMDBDataset(imdb_tokenized, partition_key=\"test\")\n",
    "\n",
    "# Only the training loader shuffles; validation and test keep dataset order.\n",
    "train_loader = DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    ")\n",
    "\n",
    "val_loader = DataLoader(\n",
    "    dataset=val_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    dataset=test_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6c977b07-2559-4ff7-85c4-d62785e81558",
   "metadata": {},
   "source": [
    "# 4 Initializing DistilBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "10b1e449-7e00-4965-9883-f97374503996",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"distilbert-base-uncased\", num_labels=2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e9854e27-fc4f-418e-996f-a8ef04a295ac",
   "metadata": {},
   "source": [
    "**Freeze all layers**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c14e61f3-3279-4725-ad7a-3c77f85aa39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a6d3e532-3684-4dbf-a8ce-60b6c6675a00",
   "metadata": {},
   "source": [
    "**Add LoRA layers**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1bca22b5-8a48-4006-929f-66734a745a50",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistilBertForSequenceClassification(\n",
       "  (distilbert): DistilBertModel(\n",
       "    (embeddings): Embeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (transformer): Transformer(\n",
       "      (layer): ModuleList(\n",
       "        (0-5): 6 x TransformerBlock(\n",
       "          (attention): MultiHeadSelfAttention(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (ffn): FFN(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (activation): GELUActivation()\n",
       "          )\n",
       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       "  (dropout): Dropout(p=0.2, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "296f865e-68c6-4bc8-b352-350f3d7f7a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoRALayer(torch.nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, rank, alpha):\n",
    "        super().__init__()\n",
    "        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())\n",
    "        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n",
    "        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.alpha * (x @ self.W_a @ self.W_b)\n",
    "        return x\n",
    "\n",
    "\n",
    "class LinearWithLoRA(torch.nn.Module):\n",
    "    \"\"\"Wrap a linear layer and add the output of a LoRALayer to its output.\"\"\"\n",
    "\n",
    "    def __init__(self, linear, rank, alpha):\n",
    "        super().__init__()\n",
    "        self.linear = linear\n",
    "        # The LoRA branch takes its dimensions from the wrapped layer.\n",
    "        in_dim, out_dim = linear.in_features, linear.out_features\n",
    "        self.lora = LoRALayer(in_dim, out_dim, rank, alpha)\n",
    "\n",
    "    def forward(self, x):\n",
    "        base_out = self.linear(x)\n",
    "        lora_out = self.lora(x)\n",
    "        return base_out + lora_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4d73e454-f0d8-411c-ba49-7cd46cf0bf64",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "lora_r = 8\n",
    "lora_alpha = 16\n",
    "# NOTE(review): lora_dropout is not read anywhere in the visible cells;\n",
    "# the LoRALayer defined in this notebook applies no dropout.\n",
    "lora_dropout = 0.05\n",
    "lora_query = True\n",
    "lora_key = False\n",
    "lora_value = True\n",
    "lora_projection = False\n",
    "lora_mlp = False\n",
    "lora_head = False\n",
    "\n",
    "# NOTE(review): `layers` is not referenced in the visible cells.\n",
    "layers = []\n",
    "\n",
    "assign_lora = partial(LinearWithLoRA, rank=lora_r, alpha=lora_alpha)\n",
    "\n",
    "# (attribute name, enabled) pairs for the attention projections of each\n",
    "# transformer block, listed in the order in which they get wrapped.\n",
    "attention_targets = (\n",
    "    (\"q_lin\", lora_query),\n",
    "    (\"k_lin\", lora_key),\n",
    "    (\"v_lin\", lora_value),\n",
    "    (\"out_lin\", lora_projection),\n",
    ")\n",
    "ffn_targets = (\"lin1\", \"lin2\")\n",
    "head_targets = (\"pre_classifier\", \"classifier\")\n",
    "\n",
    "for layer in model.distilbert.transformer.layer:\n",
    "    attention = layer.attention\n",
    "    for attr_name, enabled in attention_targets:\n",
    "        if enabled:\n",
    "            setattr(attention, attr_name, assign_lora(getattr(attention, attr_name)))\n",
    "    if lora_mlp:\n",
    "        for attr_name in ffn_targets:\n",
    "            setattr(layer.ffn, attr_name, assign_lora(getattr(layer.ffn, attr_name)))\n",
    "\n",
    "if lora_head:\n",
    "    for attr_name in head_targets:\n",
    "        setattr(model, attr_name, assign_lora(getattr(model, attr_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f9794ef2-e074-4e13-af17-990a499b03bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistilBertForSequenceClassification(\n",
       "  (distilbert): DistilBertModel(\n",
       "    (embeddings): Embeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (transformer): Transformer(\n",
       "      (layer): ModuleList(\n",
       "        (0-5): 6 x TransformerBlock(\n",
       "          (attention): MultiHeadSelfAttention(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (q_lin): LinearWithLoRA(\n",
       "              (linear): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (lora): LoRALayer()\n",
       "            )\n",
       "            (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (v_lin): LinearWithLoRA(\n",
       "              (linear): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (lora): LoRALayer()\n",
       "            )\n",
       "            (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
       "          )\n",
       "          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (ffn): FFN(\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "            (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (activation): GELUActivation()\n",
       "          )\n",
       "          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       "  (dropout): Dropout(p=0.2, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4f852fc6-3b00-4831-a973-e03b539a6cd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "distilbert.embeddings.word_embeddings.weight: False\n",
      "distilbert.embeddings.position_embeddings.weight: False\n",
      "distilbert.embeddings.LayerNorm.weight: False\n",
      "distilbert.embeddings.LayerNorm.bias: False\n",
      "distilbert.transformer.layer.0.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.0.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.0.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.0.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.0.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.0.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.0.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.0.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.0.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.0.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.0.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.0.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.0.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.0.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.0.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.0.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.0.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.0.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.0.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.0.output_layer_norm.bias: False\n",
      "distilbert.transformer.layer.1.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.1.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.1.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.1.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.1.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.1.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.1.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.1.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.1.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.1.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.1.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.1.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.1.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.1.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.1.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.1.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.1.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.1.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.1.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.1.output_layer_norm.bias: False\n",
      "distilbert.transformer.layer.2.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.2.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.2.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.2.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.2.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.2.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.2.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.2.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.2.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.2.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.2.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.2.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.2.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.2.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.2.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.2.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.2.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.2.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.2.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.2.output_layer_norm.bias: False\n",
      "distilbert.transformer.layer.3.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.3.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.3.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.3.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.3.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.3.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.3.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.3.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.3.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.3.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.3.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.3.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.3.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.3.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.3.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.3.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.3.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.3.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.3.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.3.output_layer_norm.bias: False\n",
      "distilbert.transformer.layer.4.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.4.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.4.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.4.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.4.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.4.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.4.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.4.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.4.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.4.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.4.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.4.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.4.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.4.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.4.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.4.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.4.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.4.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.4.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.4.output_layer_norm.bias: False\n",
      "distilbert.transformer.layer.5.attention.q_lin.linear.weight: False\n",
      "distilbert.transformer.layer.5.attention.q_lin.linear.bias: False\n",
      "distilbert.transformer.layer.5.attention.q_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.5.attention.q_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.5.attention.k_lin.weight: False\n",
      "distilbert.transformer.layer.5.attention.k_lin.bias: False\n",
      "distilbert.transformer.layer.5.attention.v_lin.linear.weight: False\n",
      "distilbert.transformer.layer.5.attention.v_lin.linear.bias: False\n",
      "distilbert.transformer.layer.5.attention.v_lin.lora.W_a: True\n",
      "distilbert.transformer.layer.5.attention.v_lin.lora.W_b: True\n",
      "distilbert.transformer.layer.5.attention.out_lin.weight: False\n",
      "distilbert.transformer.layer.5.attention.out_lin.bias: False\n",
      "distilbert.transformer.layer.5.sa_layer_norm.weight: False\n",
      "distilbert.transformer.layer.5.sa_layer_norm.bias: False\n",
      "distilbert.transformer.layer.5.ffn.lin1.weight: False\n",
      "distilbert.transformer.layer.5.ffn.lin1.bias: False\n",
      "distilbert.transformer.layer.5.ffn.lin2.weight: False\n",
      "distilbert.transformer.layer.5.ffn.lin2.bias: False\n",
      "distilbert.transformer.layer.5.output_layer_norm.weight: False\n",
      "distilbert.transformer.layer.5.output_layer_norm.bias: False\n",
      "pre_classifier.weight: False\n",
      "pre_classifier.bias: False\n",
      "classifier.weight: False\n",
      "classifier.bias: False\n"
     ]
    }
   ],
   "source": [
    "# Check if linear layers are frozen\n",
    "for name, param in model.named_parameters():\n",
    "    print(f\"{name}: {param.requires_grad}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ddc77229-2ab1-4d9a-b7a6-d37610e00ee4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trainable parameters: 147456\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "\n",
    "print(\"Total number of trainable parameters:\", count_parameters(model))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "83971bc4-143d-4d82-b5cb-8774f9112f53",
   "metadata": {},
   "source": [
    "# 5 Finetuning"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "730b9a87-709e-4f3d-9c9f-475ccbb56908",
   "metadata": {},
   "source": [
    "**Wrap in LightningModule for Training**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9f2c474d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from local_model_utilities import CustomLightningModule\n",
    "\n",
    "lightning_model = CustomLightningModule(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e6dab813-e1fc-47cd-87a1-5eb8070699c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = [\n",
    "    ModelCheckpoint(\n",
    "        save_top_k=1, mode=\"max\", monitor=\"val_acc\"\n",
    "    )  # save top 1 model\n",
    "]\n",
    "logger = CSVLogger(save_dir=\"logs/\", name=\"my-model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "492aa043-02da-459e-a266-091b34254ac6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 16bit Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = L.Trainer(\n",
    "    max_epochs=3,\n",
    "    callbacks=callbacks,\n",
    "    accelerator=\"gpu\",\n",
    "    precision=\"16-mixed\",\n",
    "    devices=1,\n",
    "    logger=logger,\n",
    "    log_every_n_steps=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ff08ea9a-faf6-410d-9a2f-a0d2b92dd027",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "Missing logger folder: logs/my-model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name     | Type                                | Params\n",
      "-----------------------------------------------------------------\n",
      "0 | model    | DistilBertForSequenceClassification | 67.1 M\n",
      "1 | val_acc  | MulticlassAccuracy                  | 0     \n",
      "2 | test_acc | MulticlassAccuracy                  | 0     \n",
      "-----------------------------------------------------------------\n",
      "147 K     Trainable params\n",
      "67.0 M    Non-trainable params\n",
      "67.1 M    Total params\n",
      "268.410   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8913b22260bd4a27bc3c52ddda776e3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36eeecda87d04728866eb6c94c4ce002",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfaa4e9a1c7742dea678abc0de88b284",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14a6fe2ca58c440b94e789135d0728e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16aa2e6f7b3a4bc3a958a4949705c682",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time elapsed 13.27 min\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# time.perf_counter() is intended for measuring elapsed intervals, unlike\n",
    "# time.time(), which follows the wall clock.\n",
    "start = time.perf_counter()\n",
    "\n",
    "trainer.fit(\n",
    "    model=lightning_model,\n",
    "    train_dataloaders=train_loader,\n",
    "    val_dataloaders=val_loader,\n",
    ")\n",
    "\n",
    "end = time.perf_counter()\n",
    "elapsed = end - start\n",
    "print(f\"Time elapsed {elapsed/60:.2f} min\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d795778a-70d2-4b04-96fb-598eccbcd1be",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n",
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09a77860fcf24ee7b1d14892f205d0f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "91c555dcdce449898c9b48aea2147696",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loaded model weights from the checkpoint at logs/my-model/version_0/checkpoints/epoch=0-step=2917.ckpt\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5909c587ff14a35ba8d605993d69d3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Testing: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Score the best checkpoint on each split; the unpacking order matches the loader order.\n",
    "# NOTE(review): the stored output shows a Lightning warning about a shuffling sampler,\n",
    "# presumably triggered by train_loader -- confirm before comparing per-batch results.\n",
    "train_acc, val_acc, test_acc = [\n",
    "    trainer.test(lightning_model, dataloaders=split_loader, ckpt_path=\"best\", verbose=False)\n",
    "    for split_loader in (train_loader, val_loader, test_loader)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9f211076-cfea-4779-a05c-eb8073ee41dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train acc: 90.62%\n",
      "Val acc:   89.76%\n",
      "Test acc:  88.71%\n"
     ]
    }
   ],
   "source": [
    "# Each trainer.test() result is indexed at [0] to reach its 'accuracy' entry.\n",
    "split_results = {\"Train acc:\": train_acc, \"Val acc:\": val_acc, \"Test acc:\": test_acc}\n",
    "for label, result in split_results.items():\n",
    "    # Left-align labels to a common width so the percentages line up.\n",
    "    print(f\"{label:<10} {result[0]['accuracy']*100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "a3855d32-906b-46b9-b536-61d6cbefcb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Cleanup checkpoint files as we don't need them later\n",
    "log_dir = \"logs/my-model\"  # plain literal: no placeholders, so no f-prefix is needed\n",
    "if os.path.exists(log_dir):\n",
    "    # Guarded so the cell can be re-run after the directory has already been removed.\n",
    "    shutil.rmtree(log_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
